-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Script mode support for Estimator class #2834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Script mode support for Estimator class #2834
Conversation
Codecov Report
@@ Coverage Diff @@
## master-jumpstart #2834 +/- ##
=================================================
Coverage 89.16% 89.17%
=================================================
Files 185 185
Lines 16047 16069 +22
=================================================
+ Hits 14308 14329 +21
- Misses 1739 1740 +1
Continue to review full report at Codecov.
|
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you be able to add unit tests to verify that:
git_config
is also handled correctly and similarly b7DummyFramework
andEstimator
hyperparameters
are handled correctly (if not covered already)
src/sagemaker/estimator.py
Outdated
@@ -437,6 +581,21 @@ def _get_or_create_name(self, name=None): | |||
self._ensure_base_job_name() | |||
return name_from_base(self.base_job_name) | |||
|
|||
@staticmethod | |||
def _json_encode_hyperparameters(hyperparameters): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know you are just refactoring, but is it possible to add args & return typings?
src/sagemaker/estimator.py
Outdated
self._prepare_rules() | ||
self._prepare_debugger_for_training() | ||
self._prepare_profiler_for_training() | ||
|
||
def _script_mode_hyperparam_update(self, code_dir, script): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment.
src/sagemaker/estimator.py
Outdated
|
||
self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams)) | ||
|
||
def _stage_user_code_in_s3(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment.
src/sagemaker/estimator.py
Outdated
self.uploaded_code = self._stage_user_code_in_s3() | ||
code_dir = self.uploaded_code.s3_prefix | ||
script = self.uploaded_code.script_name | ||
def _script_mode_hyperparam_update(self, code_dir, script): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment.
src/sagemaker/estimator.py
Outdated
code_dir (str): The directory hosting the training scripts. | ||
script (str): The relative filepath of the training entry-point script. | ||
""" | ||
hyperparams = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typing please:
hyperparams: Dict[str, str] = {}
tests/unit/test_estimator.py
Outdated
@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3") | ||
def test_script_mode_estimator(patched_stage_user_code, sagemaker_session): | ||
patched_stage_user_code.return_value = UploadedCode( | ||
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: are you trying to be consistent with the rest of the module when you use the ""%(*args) pattern?
If not, could you please use f-string?
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
1441e9f
to
7295190
Compare
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few non-blocking comments/suggestions.
src/sagemaker/estimator.py
Outdated
try to use either CodeCommit credential helper or local | ||
credential storage for authentication. | ||
hyperparameters (dict): Dictionary containing the hyperparameters to | ||
initialize this estimator with. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Default: None
)
src/sagemaker/estimator.py
Outdated
If not specified, the default ``code location`` is s3://output_bucket/job-name/. | ||
entry_point (str): Path (absolute or relative) to the local Python | ||
source file which should be executed as the entry point to | ||
training. If ``source_dir`` is specified, then ``entry_point`` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Default: None
)
src/sagemaker/estimator.py
Outdated
@@ -437,6 +582,21 @@ def _get_or_create_name(self, name=None): | |||
self._ensure_base_job_name() | |||
return name_from_base(self.base_job_name) | |||
|
|||
@staticmethod | |||
def _json_encode_hyperparameters(hyperparameters: dict) -> dict: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could you use Dict[str, Any]
instead of dict
for both the argument and the return type.
src/sagemaker/estimator.py
Outdated
self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(hyperparams)) | ||
|
||
def _stage_user_code_in_s3(self) -> str: | ||
"""Upload the user training script to s3 and return the location. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: ...and return the S3 URI.
code_s3_prefix = "/".join(filter(None, [key_prefix, self._current_job_name, "source"])) | ||
|
||
output_bucket, _ = parse_s3_url(self.output_path) | ||
kms_key = self.output_kms_key if code_bucket == output_bucket else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
non-blocking question: i know you are just refactoring this code, but is this a concern? i.e. are we conforming to customer expectation if we do not use the "output" encryption key when the script gets uploaded to a different bucket than the output bucket?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No idea, this section was directly lifted from somewhere else in the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will need to dive into this.
src/sagemaker/estimator.py
Outdated
@@ -2376,6 +2690,10 @@ def _model_entry_point(self): | |||
|
|||
return None | |||
|
|||
def set_hyperparameters(self, **kwargs): | |||
"""Sets hyperparameters.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's avoid repeating function name in docstring. How about:
"""Escape the dict argument as JSON, update the private hyperparameter attribute."""
repack=self.source_dir | ||
and self.entry_point | ||
and not (self.key_prefix or self.git_config), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit for readability, i would vote for:
is_repack = self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
self._upload_code(deploy_key_prefix, repack=is_repack)
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
AWS CodeBuild CI Report
Powered by github-codebuild-logs, available on the AWS Serverless Application Repository |
Issue #, if available:
Description of changes:
This PR adds script mode support to the
Estimator
andEstimatorBase
classes.This was basically done by adding the parameters
source_dir, git_config, hyperparameters, container_log_level, code_location, entry_point, dependencies
to theEstimator
andEstimatorBase
classes.Testing done:
The changes to the
Estimator
class do not break any existing unit tests. In addition, new unit tests were added to simulate the script mode use case for theEstimator
class, and confirm that the calls tosagemaker.create_training_job()
ands3
are the same for theEstimator
class andFramework
class when both use script mode. A test was also added to ensure that git support works with theEstimator
class. Integration tests will be introduced in a subsequent PR.Merge Checklist
Put an
x
in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your pull request.General
Tests
unique_name_from_base
to create resource names in integ tests (if appropriate)By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.